{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "htW5SiGzeXYm"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Probability Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "9HGeUNoteaSm"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JJ3UDciDVcB5"
      },
      "source": [
        "# Bayesian Gaussian Mixture Model and Hamiltonian MCMC\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/probability/examples/Bayesian_Gaussian_Mixture_Model\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/probability/tensorflow_probability/examples/jupyter_notebooks/Bayesian_Gaussian_Mixture_Model.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lin40yCC6eBo"
      },
      "source": [
        "In this colab we'll explore sampling from the posterior of a Bayesian Gaussian Mixture Model (BGMM) using only TensorFlow Probability primitives."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eZs1ShikNBK2"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7JjokKMbk2hJ"
      },
      "source": [
        "For $k\\in\\{1,\\ldots, K\\}$ mixture components each of dimension $D$, we'd like to model $i\\in\\{1,\\ldots,N\\}$ iid samples using the following Bayesian Gaussian Mixture Model:\n",
        "\n",
        "$$\\begin{align*}\n",
        "\\theta &\\sim \\text{Dirichlet}(\\text{concentration}=\\alpha_0)\\\\\n",
        "\\mu_k &\\sim \\text{Normal}(\\text{loc}=\\mu_{0k}, \\text{scale}=I_D)\\\\\n",
        "T_k &\\sim \\text{Wishart}(\\text{df}=5, \\text{scale}=I_D)\\\\\n",
        "Z_i &\\sim \\text{Categorical}(\\text{probs}=\\theta)\\\\\n",
        "Y_i &\\sim \\text{Normal}(\\text{loc}=\\mu_{Z_i}, \\text{scale}=T_{Z_i}^{-1/2})\n",
        "\\end{align*}$$"
      ]
    },
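    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the generative process concrete, the cell below is a minimal NumPy sketch of ancestral sampling from the model above. It is an illustration under stated assumptions, not part of the TFP workflow. The helper name `sample_bgmm_prior` is hypothetical, and `df` is assumed to be an integer, so a $\\text{Wishart}(\\text{df}, I_D)$ draw can be formed as $X^\\top X$ with $X$ a `[df, D]` matrix of standard normals. The defaults $\\alpha_0 = 0.1$ and $\\text{df}=5$, and the prior component means, mirror the values used later in this notebook. $Y_i$ is formed as $\\mu_{Z_i} + L^{-\\top}\\epsilon$ with $T_{Z_i} = LL^\\top$, which gives $\\text{Cov}(Y_i) = T_{Z_i}^{-1}$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def sample_bgmm_prior(num_samples, mu0, alpha0=0.1, df=5, seed=0):\n",
        "  # Hypothetical helper: ancestral sampling from the BGMM (illustration only).\n",
        "  rng = np.random.default_rng(seed)\n",
        "  num_components, dims = mu0.shape\n",
        "  # theta ~ Dirichlet(alpha0).\n",
        "  theta = rng.dirichlet(np.full(num_components, alpha0))\n",
        "  # mu_k ~ Normal(mu0_k, I_D).\n",
        "  mu = mu0 + rng.standard_normal((num_components, dims))\n",
        "  # T_k ~ Wishart(df, I_D), assuming integer df: T_k = X^T X, X ~ N(0, I) of shape [df, D].\n",
        "  x = rng.standard_normal((num_components, df, dims))\n",
        "  precision = np.einsum('kij,kil->kjl', x, x)\n",
        "  # Z_i ~ Categorical(theta).\n",
        "  z = rng.choice(num_components, size=num_samples, p=theta)\n",
        "  # Y_i = mu_{Z_i} + L^{-T} eps with T_{Z_i} = L L^T, so Cov(Y_i) = T_{Z_i}^{-1}.\n",
        "  chol = np.linalg.cholesky(precision)\n",
        "  eps = rng.standard_normal((num_samples, dims))\n",
        "  y = np.stack([mu[k] + np.linalg.solve(chol[k].T, e) for k, e in zip(z, eps)])\n",
        "  return theta, mu, precision, z, y\n",
        "\n",
        "mu0 = np.array([[-1., -1.], [0., 0.], [1., 1.]])\n",
        "theta, mu, precision, z, y = sample_bgmm_prior(1000, mu0)\n",
        "print(theta, y.shape)"
      ]
    },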
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iySRABi0qZnQ"
      },
      "source": [
        "Note that the `scale` arguments all have Cholesky semantics. We use this convention because it is the one used by TF Distributions, which adopts it in part because it is computationally advantageous."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Y6X_Beihwzyi"
      },
      "source": [
        "Our goal is to generate samples from the posterior:\n",
        "\n",
        "$$p\\left(\\theta, \\{\\mu_k, T_k\\}_{k=1}^K \\Big| \\{y_i\\}_{i=1}^N, \\alpha_0, \\{\\mu_{0k}\\}_{k=1}^K\\right)$$\n",
        "\n",
        "Notice that $\\{Z_i\\}_{i=1}^N$ is not present: we're interested only in random variables whose number doesn't grow with $N$. (Luckily, there's a TF distribution, `tfd.MixtureSameFamily`, that handles marginalizing out $Z_i$.)\n",
        "\n",
        "It is not possible to directly sample from this distribution owing to a computationally intractable normalization term.\n",
        "\n",
        "[Metropolis-Hastings algorithms](https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm) are a family of techniques for sampling from distributions that are intractable to normalize.\n",
        "\n",
        "TensorFlow Probability offers a number of MCMC options, including several based on Metropolis-Hastings. In this notebook, we'll use [Hamiltonian Monte Carlo](https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo) (`tfp.mcmc.HamiltonianMonteCarlo`). HMC is often a good choice because it can converge rapidly, samples the state space jointly (as opposed to coordinatewise), and leverages one of TF's virtues: automatic differentiation. That said, other approaches, e.g., [Gibbs sampling](https://en.wikipedia.org/wiki/Gibbs_sampling), might actually be better suited to sampling from a BGMM posterior."
      ]
    },
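    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The key property of Metropolis-Hastings is that the acceptance probability depends on the target only through the ratio $p(x')/p(x)$, in which the intractable normalizing constant cancels. The cell below is a minimal NumPy sketch of random-walk Metropolis (Metropolis-Hastings with a symmetric proposal) targeting a 1-D density known only up to a constant. The helper name `random_walk_metropolis` and the Gaussian target are assumptions for illustration. HMC differs in that it proposes moves by simulating Hamiltonian dynamics driven by gradients of the log-density."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def random_walk_metropolis(unnormalized_log_prob, x0, num_steps, step_size, seed=0):\n",
        "  # Hypothetical helper. The Gaussian proposal is symmetric, so the Hastings\n",
        "  # correction cancels and we accept with probability min(1, p(x') / p(x)).\n",
        "  rng = np.random.default_rng(seed)\n",
        "  x, log_p = x0, unnormalized_log_prob(x0)\n",
        "  samples = np.empty(num_steps)\n",
        "  num_accepted = 0\n",
        "  for t in range(num_steps):\n",
        "    proposal = x + step_size * rng.standard_normal()\n",
        "    log_p_proposal = unnormalized_log_prob(proposal)\n",
        "    if np.log(rng.uniform()) < log_p_proposal - log_p:\n",
        "      x, log_p = proposal, log_p_proposal\n",
        "      num_accepted += 1\n",
        "    samples[t] = x\n",
        "  return samples, num_accepted / num_steps\n",
        "\n",
        "def target_log_prob(x):\n",
        "  # Normal(3, 0.5) log-density with its normalizing constant deliberately dropped.\n",
        "  return -0.5 * ((x - 3.) / 0.5) ** 2\n",
        "\n",
        "samples, acceptance_rate = random_walk_metropolis(\n",
        "    target_log_prob, x0=0., num_steps=20000, step_size=1.)\n",
        "kept = samples[2000:]  # discard burn-in\n",
        "print(acceptance_rate, kept.mean(), kept.std())"
      ]
    },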
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "uswTWdgNu46j"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "\n",
        "\n",
        "import functools\n",
        "\n",
        "import matplotlib.pyplot as plt; plt.style.use('ggplot')\n",
        "import numpy as np\n",
        "import seaborn as sns; sns.set_context('notebook')\n",
        "\n",
        "import tensorflow.compat.v2 as tf\n",
        "tf.enable_v2_behavior()\n",
        "import tensorflow_probability as tfp\n",
        "\n",
        "tfd = tfp.distributions\n",
        "tfb = tfp.bijectors\n",
        "\n",
        "physical_devices = tf.config.experimental.list_physical_devices('GPU')\n",
        "if len(physical_devices) \u003e 0:\n",
        "  tf.config.experimental.set_memory_growth(physical_devices[0], True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Uj9uHZN2yUqz"
      },
      "source": [
        "Before building the model, we'll need to define a new type of distribution. From the model specification above, it's clear we're parameterizing the MVN with an inverse covariance matrix, i.e., a [precision matrix](https://en.wikipedia.org/wiki/Precision_(statistics%29). To accomplish this in TF, we'll need to roll our own `Bijector`. This `Bijector` will use the forward transformation:\n",
        "\n",
        "- `Y = tf.linalg.triangular_solve(chol_precision_tril, X, adjoint=True) + loc`.\n",
        "\n",
        "And the `log_prob` calculation is just the inverse, i.e.:\n",
        "\n",
        "- `X = tf.linalg.matmul(chol_precision_tril, Y - loc, adjoint_a=True)`.\n",
        "\n",
        "Since all we need for HMC is `log_prob`, we never have to call `tf.linalg.triangular_solve` (as we would with `tfd.MultivariateNormalTriL`). This is advantageous since `tf.linalg.matmul` is usually faster owing to better cache locality.\n"
      ]
    },
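    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For intuition, here is the same log-density computation written in NumPy. This is a sketch: the helper names are hypothetical and this is not TFP's implementation. Writing the precision as $P = LL^\\top$, the inverse map is $x = L^\\top(y - \\text{loc})$ and the change-of-variables term is $\\log|\\det L^\\top| = \\sum_i \\log L_{ii}$, so\n",
        "\n",
        "$$\\log p(y) = -\\tfrac{1}{2}\\lVert L^\\top(y - \\text{loc})\\rVert^2 - \\tfrac{D}{2}\\log 2\\pi + \\sum_{i=1}^D \\log L_{ii}.$$\n",
        "\n",
        "Only a matrix multiply by $L$ is needed. The cell checks the result against the textbook formula written in terms of the covariance $C = P^{-1}$, using the same `loc` and Cholesky factor as the sanity check later in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def mvn_log_prob_chol_precision(y, loc, chol_precision):\n",
        "  # Hypothetical helper: inverse map x = L^T (y - loc), written for row vectors.\n",
        "  x = (y - loc) @ chol_precision\n",
        "  dims = loc.shape[-1]\n",
        "  # log|det L^T| = sum(log(diag(L))) is the change-of-variables term.\n",
        "  return (-0.5 * np.sum(x ** 2, axis=-1)\n",
        "          - 0.5 * dims * np.log(2. * np.pi)\n",
        "          + np.sum(np.log(np.diag(chol_precision))))\n",
        "\n",
        "def mvn_log_prob_cov(y, loc, cov):\n",
        "  # Reference: textbook MVN log-density in terms of the covariance matrix.\n",
        "  dims = loc.shape[-1]\n",
        "  diff = y - loc\n",
        "  mahalanobis = np.sum((diff @ np.linalg.inv(cov)) * diff, axis=-1)\n",
        "  _, log_det_cov = np.linalg.slogdet(cov)\n",
        "  return -0.5 * (mahalanobis + dims * np.log(2. * np.pi) + log_det_cov)\n",
        "\n",
        "loc = np.array([1., -1.])\n",
        "chol_precision = np.array([[1., 0.], [2., 8.]])\n",
        "cov = np.linalg.inv(chol_precision @ chol_precision.T)\n",
        "y = np.random.default_rng(0).standard_normal((5, 2))\n",
        "print(np.allclose(mvn_log_prob_chol_precision(y, loc, chol_precision),\n",
        "                  mvn_log_prob_cov(y, loc, cov)))"
      ]
    },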
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "nc4yy6vW-lC_"
      },
      "outputs": [],
      "source": [
        "class MVNCholPrecisionTriL(tfd.TransformedDistribution):\n",
        "  \"\"\"MVN from loc and (Cholesky) precision matrix.\"\"\"\n",
        "\n",
        "  def __init__(self, loc, chol_precision_tril, name=None):\n",
        "    super(MVNCholPrecisionTriL, self).__init__(\n",
        "        distribution=tfd.Independent(tfd.Normal(tf.zeros_like(loc),\n",
        "                                                scale=tf.ones_like(loc)),\n",
        "                                     reinterpreted_batch_ndims=1),\n",
        "        bijector=tfb.Chain([\n",
        "            tfb.Shift(shift=loc),\n",
        "            tfb.Invert(tfb.ScaleMatvecTriL(scale_tril=chol_precision_tril,\n",
        "                                           adjoint=True)),\n",
        "        ]),\n",
        "        name=name)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JDOkWhDQg4ZG"
      },
      "source": [
        "The `tfd.Independent` distribution turns independent draws of one distribution into a multivariate distribution with statistically independent coordinates. In terms of computing `log_prob`, this \"meta-distribution\" manifests as a simple sum over the event dimension(s).\n",
        "\n",
        "Also notice that we took the `adjoint` (\"transpose\") of the scale matrix. This is because if the precision is the inverse covariance, i.e., $P=C^{-1}$, and if $C=AA^\\top$, then $P=BB^{\\top}$ where $B=A^{-\\top}$."
      ]
    },
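    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A quick numerical check of this identity, as a NumPy sketch with an arbitrarily chosen positive definite covariance: if $C = AA^\\top$ with $A$ lower triangular, then $B = A^{-\\top}$ satisfies $BB^\\top = A^{-\\top}A^{-1} = (AA^\\top)^{-1} = C^{-1} = P$. The same algebra explains the transformation used in `MVNCholPrecisionTriL`: with $P = LL^\\top$, the scale $L^{-\\top}$ gives $\\text{Cov}(L^{-\\top}x) = L^{-\\top}L^{-1} = P^{-1} = C$ for standard normal $x$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(1)\n",
        "m = rng.standard_normal((3, 3))\n",
        "cov = m @ m.T + 3. * np.eye(3)  # an arbitrary positive definite covariance C\n",
        "a = np.linalg.cholesky(cov)  # C = A A^T with A lower triangular\n",
        "b = np.linalg.inv(a).T  # B = A^{-T}\n",
        "precision = np.linalg.inv(cov)  # P = C^{-1}\n",
        "print(np.allclose(b @ b.T, precision))  # P = B B^T\n",
        "\n",
        "l = np.linalg.cholesky(precision)  # P = L L^T with L lower triangular\n",
        "scale = np.linalg.inv(l).T  # the effective scale L^{-T}\n",
        "print(np.allclose(scale @ scale.T, cov))  # L^{-T} L^{-1} = C"
      ]
    },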
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Pfkc8cmhh2Qz"
      },
      "source": [
        "Since this distribution is kind of tricky, let's quickly verify that our `MVNCholPrecisionTriL` works as we think it should."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 152
        },
        "colab_type": "code",
        "id": "GhqbjwlIh1Vn",
        "outputId": "3ea12c10-cb9b-4558-aedd-386b37adc909"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "true mean: [ 1. -1.]\n",
            "sample mean: [ 1.0002806 -1.000105 ]\n",
            "true cov:\n",
            " [[ 1.0625   -0.03125 ]\n",
            " [-0.03125   0.015625]]\n",
            "sample cov:\n",
            " [[ 1.0641273  -0.03126175]\n",
            " [-0.03126175  0.01559312]]\n"
          ]
        }
      ],
      "source": [
        "def compute_sample_stats(d, seed=42, n=int(1e6)):\n",
        "  x = d.sample(n, seed=seed)\n",
        "  sample_mean = tf.reduce_mean(x, axis=0, keepdims=True)\n",
        "  s = x - sample_mean\n",
        "  sample_cov = tf.linalg.matmul(s, s, adjoint_a=True) / tf.cast(n, s.dtype)\n",
        "  sample_scale = tf.linalg.cholesky(sample_cov)\n",
        "  sample_mean = sample_mean[0]\n",
        "  return [\n",
        "      sample_mean,\n",
        "      sample_cov,\n",
        "      sample_scale,\n",
        "  ]\n",
        "\n",
        "dtype = np.float32\n",
        "true_loc = np.array([1., -1.], dtype=dtype)\n",
        "true_chol_precision = np.array([[1., 0.],\n",
        "                                [2., 8.]],\n",
        "                               dtype=dtype)\n",
        "true_precision = np.matmul(true_chol_precision, true_chol_precision.T)\n",
        "true_cov = np.linalg.inv(true_precision)\n",
        "\n",
        "d = MVNCholPrecisionTriL(\n",
        "    loc=true_loc,\n",
        "    chol_precision_tril=true_chol_precision)\n",
        "\n",
        "[sample_mean, sample_cov, sample_scale] = [\n",
        "    t.numpy() for t in compute_sample_stats(d)]\n",
        "\n",
        "print('true mean:', true_loc)\n",
        "print('sample mean:', sample_mean)\n",
        "print('true cov:\\n', true_cov)\n",
        "print('sample cov:\\n', sample_cov)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "N60z8scN1v6E"
      },
      "source": [
        "Since the sample mean and covariance are close to the true mean and covariance, it seems the distribution is correctly implemented. Now, we'll use `MVNCholPrecisionTriL` and `tfp.distributions.JointDistributionNamed` to specify the BGMM model. For the observational model, we'll use `tfd.MixtureSameFamily` to automatically integrate out the $\\{Z_i\\}_{i=1}^N$ draws."
      ]
    },
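    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Integrating out $Z_i$ gives the per-observation log-likelihood $\\log p(y_i) = \\log \\sum_{k=1}^K \\theta_k\\, \\mathcal{N}(y_i \\mid \\mu_k, T_k^{-1})$, which can be evaluated stably with the log-sum-exp trick and then summed over the iid observations. The NumPy sketch below (helper names hypothetical) computes this quantity for `[N, D]` observations, `[K]` mixture probabilities, `[K, D]` locations and `[K, D, D]` Cholesky precision factors, matching the shapes used in this model. It illustrates the quantity that `tfd.MixtureSameFamily` wrapped in `tfd.Sample` is expected to produce, not how TFP implements it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def log_sum_exp(a, axis=-1):\n",
        "  # Numerically stable log(sum(exp(a))) along `axis`.\n",
        "  a_max = np.max(a, axis=axis, keepdims=True)\n",
        "  return np.squeeze(a_max, axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))\n",
        "\n",
        "def mixture_log_prob(y, mix_probs, loc, chol_precision):\n",
        "  # Hypothetical helper. Shapes: y [N, D], mix_probs [K], loc [K, D],\n",
        "  # chol_precision [K, D, D] (lower triangular, positive diagonal).\n",
        "  dims = y.shape[-1]\n",
        "  diff = y[:, None, :] - loc[None, :, :]  # [N, K, D]\n",
        "  x = np.einsum('nkd,kde->nke', diff, chol_precision)  # rows of L_k^T (y_n - loc_k)\n",
        "  log_det = np.sum(\n",
        "      np.log(np.diagonal(chol_precision, axis1=-2, axis2=-1)), axis=-1)  # [K]\n",
        "  component_log_prob = (-0.5 * np.sum(x ** 2, axis=-1)\n",
        "                        - 0.5 * dims * np.log(2. * np.pi) + log_det)  # [N, K]\n",
        "  # log p(y_n) = logsumexp_k [log theta_k + log N(y_n | loc_k, T_k^{-1})].\n",
        "  per_observation = log_sum_exp(np.log(mix_probs) + component_log_prob, axis=-1)\n",
        "  return np.sum(per_observation)  # iid observations: sum over n\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "y = rng.standard_normal((10, 2))\n",
        "mix_probs = np.array([0.2, 0.5, 0.3])\n",
        "loc = np.array([[-2., -2.], [0., 0.], [2., 2.]])\n",
        "chol_precision = np.stack([np.eye(2)] * 3)\n",
        "print(mixture_log_prob(y, mix_probs, loc, chol_precision))"
      ]
    },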
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xhzxySDjL2-S"
      },
      "outputs": [],
      "source": [
        "dtype = np.float64\n",
        "dims = 2\n",
        "components = 3\n",
        "num_samples = 1000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xAOmHhZ7LzDQ"
      },
      "outputs": [],
      "source": [
        "bgmm = tfd.JointDistributionNamed(dict(\n",
        "  mix_probs=tfd.Dirichlet(\n",
        "    concentration=np.ones(components, dtype) / 10.),\n",
        "  loc=tfd.Independent(\n",
        "    tfd.Normal(\n",
        "        loc=np.stack([\n",
        "            -np.ones(dims, dtype),\n",
        "            np.zeros(dims, dtype),\n",
        "            np.ones(dims, dtype),\n",
        "        ]),\n",
        "        scale=tf.ones([components, dims], dtype)),\n",
        "    reinterpreted_batch_ndims=2),\n",
        "  precision=tfd.Independent(\n",
        "    tfd.WishartTriL(\n",
        "        df=5,\n",
        "        scale_tril=np.stack([np.eye(dims, dtype=dtype)]*components),\n",
        "        input_output_cholesky=True),\n",
        "    reinterpreted_batch_ndims=1),\n",
        "  s=lambda mix_probs, loc, precision: tfd.Sample(tfd.MixtureSameFamily(\n",
        "      mixture_distribution=tfd.Categorical(probs=mix_probs),\n",
        "      components_distribution=MVNCholPrecisionTriL(\n",
        "          loc=loc,\n",
        "          chol_precision_tril=precision)),\n",
        "      sample_shape=num_samples)\n",
        "))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CpLnRJr2TXYD"
      },
      "outputs": [],
      "source": [
        "def joint_log_prob(observations, mix_probs, loc, chol_precision):\n",
        "  \"\"\"BGMM with priors: loc=Normal, precision=Wishart, mix=Dirichlet.\n",
        "\n",
        "  Args:\n",
        "    observations: `[n, d]`-shaped `Tensor` representing Bayesian Gaussian\n",
        "      Mixture model draws. Each sample is a length-`d` vector.\n",
        "    mix_probs: `[K]`-shaped `Tensor` representing random draw from\n",
        "      `Dirichlet` prior.\n",
        "    loc: `[K, d]`-shaped `Tensor` representing the location parameter of the\n",
        "      `K` components.\n",
        "    chol_precision: `[K, d, d]`-shaped `Tensor` representing `K` lower\n",
        "      triangular `cholesky(Precision)` matrices, each being sampled from\n",
        "      a Wishart distribution.\n",
        "\n",
        "  Returns:\n",
        "    log_prob: `Tensor` representing joint log-density over all inputs.\n",
        "  \"\"\"\n",
        "  return bgmm.log_prob(\n",
        "      mix_probs=mix_probs, loc=loc, precision=chol_precision, s=observations)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7jTMXdymV1QJ"
      },
      "source": [
        "## Generate \"Training\" Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rl4brz3G3pS7"
      },
      "source": [
        "For this demo, we'll sample some random data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1AJZAtwXV8RQ"
      },
      "outputs": [],
      "source": [
        "true_loc = np.array([[-2., -2],\n",
        "                     [0, 0],\n",
        "                     [2, 2]], dtype)\n",
        "random = np.random.RandomState(seed=43)\n",
        "\n",
        "true_hidden_component = random.randint(0, components, num_samples)\n",
        "observations = (true_loc[true_hidden_component] +\n",
        "                random.randn(num_samples, dims).astype(dtype))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zVOvMh7MV37A"
      },
      "source": [
        "## Bayesian Inference using HMC"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cdN3iKFT32Jp"
      },
      "source": [
        "Now that we've used TFD to specify our model and obtained some observed data, we have all the necessary pieces to run HMC.\n",
        "\n",
        "To do this, we'll use a [partial application](https://en.wikipedia.org/wiki/Partial_application) to \"pin down\" the things we don't want to sample. In this case, that means we need only pin down `observations`. (The hyper-parameters are already baked into the prior distributions and are not part of the `joint_log_prob` function signature.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tVoaDFSf7L_j"
      },
      "outputs": [],
      "source": [
        "unnormalized_posterior_log_prob = functools.partial(joint_log_prob, observations)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "a0OMIWIYeMmQ"
      },
      "outputs": [],
      "source": [
        "initial_state = [\n",
        "    tf.fill([components],\n",
        "            value=np.array(1. / components, dtype),\n",
        "            name='mix_probs'),\n",
        "    tf.constant(np.array([[-2., -2],\n",
        "                          [0, 0],\n",
        "                          [2, 2]], dtype),\n",
        "                name='loc'),\n",
        "    tf.linalg.eye(dims, batch_shape=[components], dtype=dtype, name='chol_precision'),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TVpiT3LLyfcO"
      },
      "source": [
        "### Unconstrained Representation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JS8XOsxiyiBV"
      },
      "source": [
        "Hamiltonian Monte Carlo (HMC) requires the target log-probability function to be differentiable with respect to its arguments. Furthermore, HMC can exhibit dramatically higher statistical efficiency if the state space is unconstrained.\n",
        "\n",
        "This means we'll have to work out two main issues when sampling from the BGMM posterior:\n",
        "\n",
        "1. $\\theta$ represents a discrete probability vector, i.e., must be such that $\\sum_{k=1}^K \\theta_k = 1$ and $\\theta_k\u003e0$.\n",
        "2. $T_k$ represents an inverse covariance matrix, i.e., must be such that $T_k \\succ 0$, i.e., is [positive definite](https://en.wikipedia.org/wiki/Positive-definite_matrix).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Vt9SXJzO0Cks"
      },
      "source": [
        "To address these issues, we'll:\n",
        "\n",
        "1. transform the constrained variables to an unconstrained space\n",
        "2. run the MCMC in unconstrained space\n",
        "3. transform the unconstrained variables back to the constrained space.\n",
        "\n",
        "As with `MVNCholPrecisionTriL`, we'll use [`Bijector`s](https://www.tensorflow.org/api_docs/python/tf/distributions/bijectors/Bijector) to transform random variables to unconstrained space.\n",
        "\n",
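        "- Draws from the [`Dirichlet`](https://en.wikipedia.org/wiki/Dirichlet_distribution) live on the probability simplex. We map them to unconstrained space with `tfb.SoftmaxCentered`, whose forward transformation appends a zero to $K-1$ unconstrained reals and applies the [softmax function](https://en.wikipedia.org/wiki/Softmax_function).\n",
        "\n",
        "- Our precision random variable is represented by its Cholesky factor (note `input_output_cholesky=True`), i.e., a lower-triangular matrix with positive diagonal. To unconstrain it, we'll use the `FillTriangular` and `TransformDiagonal` bijectors: `FillTriangular` converts a vector into a lower-triangular matrix, and `TransformDiagonal(Softplus())` ensures the diagonal is positive. Filling only the triangle is useful because it means sampling only $d(d+1)/2$ floats rather than $d^2$."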
      ]
    },
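    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see what these transformations do, the cell below re-implements the two forward maps in NumPy. It is a sketch only: the helper names are hypothetical, and the simple row-major fill used here is an assumption that need not match the element ordering `tfb.FillTriangular` actually uses. The first map sends $K-1$ unconstrained reals to a point on the $K$-simplex; the second sends $d(d+1)/2$ unconstrained reals to a lower-triangular matrix with positive diagonal, i.e., a valid Cholesky factor of a positive definite precision matrix."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def softmax_centered(x):\n",
        "  # Append a zero logit, then apply softmax: maps K-1 reals to the K-simplex.\n",
        "  logits = np.concatenate([x, [0.]])\n",
        "  e = np.exp(logits - np.max(logits))\n",
        "  return e / np.sum(e)\n",
        "\n",
        "def fill_lower_triangular(v, d):\n",
        "  # Place d(d+1)/2 values into the lower triangle (row-major order, an assumption).\n",
        "  m = np.zeros((d, d))\n",
        "  m[np.tril_indices(d)] = v\n",
        "  return m\n",
        "\n",
        "def to_chol_precision(v, d):\n",
        "  # Fill the triangle, then apply softplus to the diagonal so it is strictly positive.\n",
        "  m = fill_lower_triangular(v, d)\n",
        "  i = np.arange(d)\n",
        "  m[i, i] = np.log1p(np.exp(m[i, i]))\n",
        "  return m\n",
        "\n",
        "theta = softmax_centered(np.array([0.3, -1.2]))\n",
        "chol = to_chol_precision(np.array([-2., 0.5, -3.]), d=2)\n",
        "print(theta, theta.sum())\n",
        "print(chol @ chol.T)  # a positive definite precision matrix"
      ]
    },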
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "_atEQrDR7JvG"
      },
      "outputs": [],
      "source": [
        "unconstraining_bijectors = [\n",
        "    tfb.SoftmaxCentered(),\n",
        "    tfb.Identity(),\n",
        "    tfb.Chain([\n",
        "        tfb.TransformDiagonal(tfb.Softplus()),\n",
        "        tfb.FillTriangular(),\n",
        "    ])]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0zq6QJJ-NSPJ"
      },
      "outputs": [],
      "source": [
        "@tf.function(autograph=False)\n",
        "def sample():\n",
        "  return tfp.mcmc.sample_chain(\n",
        "    num_results=2000,\n",
        "    num_burnin_steps=500,\n",
        "    current_state=initial_state,\n",
        "    kernel=tfp.mcmc.SimpleStepSizeAdaptation(\n",
        "        tfp.mcmc.TransformedTransitionKernel(\n",
        "            inner_kernel=tfp.mcmc.HamiltonianMonteCarlo(\n",
        "                target_log_prob_fn=unnormalized_posterior_log_prob,\n",
        "                step_size=0.065,\n",
        "                num_leapfrog_steps=5),\n",
        "            bijector=unconstraining_bijectors),\n",
        "        num_adaptation_steps=400),\n",
        "    trace_fn=lambda _, pkr: pkr.inner_results.inner_results.is_accepted)\n",
        "\n",
        "[mix_probs, loc, chol_precision], is_accepted = sample()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QLEz96mg6fpZ"
      },
      "source": [
        "With the chain run, we'll now compute the acceptance rate and print the posterior means."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "_ceX1A3-ZFiN"
      },
      "outputs": [],
      "source": [
        "acceptance_rate = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32)).numpy()\n",
        "mean_mix_probs = tf.reduce_mean(mix_probs, axis=0).numpy()\n",
        "mean_loc = tf.reduce_mean(loc, axis=0).numpy()\n",
        "mean_chol_precision = tf.reduce_mean(chol_precision, axis=0).numpy()\n",
        "precision = tf.linalg.matmul(chol_precision, chol_precision, transpose_b=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 270
        },
        "colab_type": "code",
        "id": "bqJ6RSJxegC6",
        "outputId": "e0867545-0509-4077-d89d-74e1d5280062"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "acceptance_rate: 0.5305\n",
            "avg mix probs: [0.25248723 0.60729516 0.1402176 ]\n",
            "avg loc:\n",
            " [[-1.96466753 -2.12047249]\n",
            " [ 0.27628865  0.22944732]\n",
            " [ 2.06461244  2.54216122]]\n",
            "avg chol(precision):\n",
            " [[[ 1.05105032  0.        ]\n",
            "  [ 0.12699955  1.06553113]]\n",
            "\n",
            " [[ 0.76058015  0.        ]\n",
            "  [-0.50332767  0.77947431]]\n",
            "\n",
            " [[ 1.22770457  0.        ]\n",
            "  [ 0.70670027  1.50914164]]]\n"
          ]
        }
      ],
      "source": [
        "print('acceptance_rate:', acceptance_rate)\n",
        "print('avg mix probs:', mean_mix_probs)\n",
        "print('avg loc:\\n', mean_loc)\n",
        "print('avg chol(precision):\\n', mean_chol_precision)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 289
        },
        "colab_type": "code",
        "id": "zFOU0j9kPdUy",
        "outputId": "17f4ce0c-24c3-4cf4-ebe8-b932caac7ba4"
      },
      "outputs": [
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEQCAYAAABLMTQcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAG4dJREFUeJzt3Xts1fX9x/HXuRTanrbQcsqpxy4C4iXoJFRATWEzJArK\ngktWzP6YDpbNyIrzHpyTS+d0Y/7QKBbYENHBZGCNaOzMXDL/oCBWyhJlmbuJGigczmmp0gNSaM/v\nj3pOzzk9p6el355z+jnPR2Lo+Z5vP99PP5HXefP+XmoLhUIhAQCMYs/0BAAA1iPcAcBAhDsAGIhw\nBwADEe4AYCDCHQAMRLgjZ7zyyiuqrq5WVVWVvvjii5j3jh49qiuvvFI9PT0jOod58+bpvffeG9Fj\nABLhjhEUH2SNjY2aPXu2Dhw4EAnTqqoqVVVVac6cObr77ru1b9++fmNMnz5dVVVVmjFjhqqqqvSr\nX/1qyHM5f/681q5dq61bt+rgwYMaN25cv31sNtvQf0ggSzkzPQHkhtdff11r167V5s2bNX36dB09\nelQ2m00tLS2y2Wxqa2tTY2OjamtrtXr1an33u9+NfO/vfvc7XX/99cM6fiAQUFdXly699NLh/igj\npru7Ww6HI9PTgCGo3DHidu7cqd/+9rd68cUXNX369Jj3wjdIT5gwQXfeeafuuecePfXUUwn3SaWr\nq0tPPPGE5s6dq29961t68sknde7cOX366ae65ZZbJEmzZs3SkiVLUo514sQJLVu2TNddd53mz5+v\nV199NfJeT0+PNm3apJtuuknXXnutvve978nn8yUcZ/fu3Zo3b56uv/56bdq0Kea9559/Xj/72c/0\n8MMPa+bMmXr99df14Ycf6vvf/75mzZqluXPn6vHHH9f58+clSevXr4/8q+X8+fOaMWOG/u///k+S\ndPbsWV1zzTU6deqUurq69PDDD+u6667TrFmztHjxYrW3tw9qDWEOwh0jaseOHVq/fr1efvllTZs2\nLeX+N910k9ra2vTJJ58M+VgbN27URx99pDfffFNvvPGGPvzwQ23cuFGTJk3SW2+9JUlqaWnRSy+9\nlHKsBx54QF6vV01NTXr22Wf19NNPa//+/ZKkF198UX/+85/1wgsvqKWlRU8++aTy8/P7jfHf//5X\ndXV1euqpp7Rnzx51dHToxIkTMfv87W9/0y233KIDBw5o0aJFcjqdevTRR9Xc3KydO3dq//79euWV\nVyT1fjA1NzdLkj766CO53W598MEHkqSDBw9qypQpKi4u1uuvv67Ozk7t2bNHzc3Nqqur09ixY4e8\nnhjdCHeMqH379mn69Om6/PLLB7W/x+ORpJgTnrW1tZo9e7ZmzZql2bNnx1TR0d566y3V1taqtLRU\npaWlWr58uXbv3i2pr/ofzL8Cjh07pr///e966KGHlJeXpyuvvFKLFy/WG2+8IUlqaGjQ/fffr0su\nuUSSdMUVVyTs4f/lL3/RvHnzdO211yovL0/33ntvv31mzJihefPmSZLGjBmjadOm6ZprrpHNZpPX\n69Xtt98eCfAZM2bos88+0xdffKEPPvhANTU18vl8OnPmjA4cOKBZs2ZJkpxOpzo6OnT48GHZbDZN\nmzZNLpcr5c8Ns9Bzx4iqq6vThg0b9Oijj+rJJ59MuX+4vTF+/PjItg0bNgyq537ixAl5vd7Ia6/X\nK7/fL2loJ0v9fr/GjRungoKCmLH+8Y9/SJKOHz+ub3zjG4OaT0VFReR1QUFBzM8lKeZ9Sfr000/1\nm9/8RocOHdJXX32l7u5uXXXVVZKksWPH6uqrr1Zzc7MOHDigZcuW6eOPP1ZLS4uam5t15513SpJu\nu+02HT9+XA888IBOnTqlRYsW6f7776efn2Oo3DGiysrK9NJLL6mlpUVr1qxJuf8777wjt9utyZMn\nR7YNtufu8Xh09OjRyOvW1lZN
nDhxyHOeOHGivvjiC50+fTqy7dixY5GxKioq9Pnnn6ccp7y8XMeP\nH4+8PnPmjDo6OmL2if/QWbNmjaZMmaK//vWvOnDggO67776Yn3/mzJnav3+//vnPf+qb3/ymZs6c\nqaamJh06dEgzZ86U1Fu519bWqrGxUX/605/07rvvRv4Fg9xBuGPElZeX6+WXX1ZTU5N+/etfR7aH\nQqFIcLW1tWn79u3asGGDHnzwwQs6zq233qqNGzeqvb1d7e3t2rBhg2677baY4w0k/H5FRYVmzJih\np59+Wl1dXfr444/V0NCgRYsWSZIWL16sZ599Vp999pkk6V//+le/6+YlacGCBXr33Xd18OBBnTt3\nTs8991zKnyEYDKqoqEgFBQX63//+px07dsS8P3v2bO3evVtTp06V0+nUddddp1dffVWVlZUqLS2V\nJL3//vv697//rZ6eHhUWFsrpdFK156Cca8sEg0E1NjZq4cKFOd2HTMc6RFelFRUVeumll3THHXco\nPz9ft99+u2w2m2bNmqVQKKTCwkJdffXVeu6551RdXR0zzrJly2S399Uh1dXVWr9+fb/j/fSnP1Uw\nGNSiRYtks9l0yy236O677044n2jhtYh+f926dVq9erXmzp2rcePG6d5779UNN9wgSVq6dKnOnTun\nH/3oR+ro6NCUKVP0/PPP9+u7T506VatWrdKDDz6oM2fOaOnSpZFzCsmsWLFCK1eu1AsvvKBp06Zp\n4cKFkRO5Um/f/ezZs5H++tSpU5Wfnx95LfVe9rl69Wr5fD65XC7deuutkQ+mVPj70cuIdQjlGJ/P\nF1q8eHHI5/NleioZxTr0YS36sBa9TFgH2jIAYCDL2jJPPfWU/H6/bDab8vPztXTpUk2aNMmq4QEA\nQ2BZuC9fvjxy6diBAwe0ceNGrV271qrhAQBDYFlbJvqa4GAwGHMCLJvY7XaVl5dn7fzShXXow1r0\nYS16mbAOtlBokBcRD8KmTZv04YcfSpIeffRRVVZWxrwfDAYVDAZjtjmdTpWVlVk1BQDIKe3t7ZHn\nD4W5XC5rwz1sz549ampq0s9//vOY7bt27VJDQ0PMNo/Hk/CyNgBAavfcc0+/B9fV1NSMTLhL0g9+\n8ANt2rRJRUVFkW2JKne73S632y2fz6fu7u6RmEo/Xq9Xra2taTlWNmMd+rAWfViLXtm+Dg6HQx6P\nR4FAoN8vmXG5XNacUP3qq68UDAY1YcIESb0nVIuLi2OCPXzAUXtDAABkIbfbnXC7JeF+9uzZyK3a\nNptNxcXFWrFihRVDAwAugCXhPm7cOD3xxBNWDAUAsMDovc4HAJAU4Q4ABiLcAcBAOffIXwC57ezp\notQ7STr83y8lFWlsYefITmiEEO4AcsZgg32g7xktYU+4A8hZwc7zqXeK4ipyjpqwJ9wB5IT4UI4O\n9lQh7ypyDrBf/38NZEPgE+4Aclo4sIOd5xK+7yrKSxr+riJn5L3wB4DU90GSyZAn3AHkhLGFnRfU\nc08W+tHCIR8d8JnGpZAAckZ0JR0O4r4/8wb83tOd5/v9Fw7+ofbu0yF7PmYAIA2GWsGfjmvbBINf\nt2FcTgU7z+l00XkVxlXs4Q+Ms6czdykl4Q4AceIDXeoL9USvwyEfLHLG/Asg/uqadAY94Q4AUQYT\n7NESvRfsPBc5ERvbh09f0BPuABClsMip053nvw7n1CdTowU7z0WC/XRnb7sm/DpauKIfyYAn3AHk\nnHDfPdHVLcHOc/0C3uXq2y9ZFT9QdS8pSSU/crhaBkBOGlvYGamco6+cCVfZhUmuonG5nDH/Zavs\nnRkApNFAFXW4ko8WbtkkC/j4D4lUx7Aa4Q4gp8VfGtlbvcfelBT8ukXT+3XyPnz8JZF9Yya+hn4k\n++6EO4CcFx2w4aBPVmX3BX/sCddUN0ElG2+kAp5wB4AoiW5yig/mC33cQDoDnnAHgDiDuYs18Z
U2\n5wd8P50IdwBIYPLUErW2tg56/2SXVg4GbRkAyFKJAjpZ9Z+OxxAQ7gAwQjL5PHduYgIAAxHuAGAg\nwh0ADES4A4CBCHcAMJAlV8t0dnZq/fr1OnHihJxOpyoqKnTXXXepuLjYiuEBAENk2aWQt912m6ZN\nmyZJ2r59u/74xz/q7rvvtmp4AMAQWNKWKSoqigS7JF122WUKBAJWDA0AuACW99xDoZDeeecdzZw5\n0+qhAQCDZPkdqlu2bFFBQYEWLFjQ771gMKhgMBizzW63y+12Wz0NAMgJgUBAPT09MdtcLpdsoVAo\nZNVBtm3bps8//1yPPPKIHA5Hv/d37dqlhoaGmG3l5eWqr6+3agoAkFNqa2vl9/tjttXU1FgX7jt2\n7NB//vMfPfLIIxozZkzCfQaq3H0+n7q7u62YSkper3dIT3szFevQh7Xow1r0yvZ1cDgc8ng8SSt3\nS9oyR44c0e7du+X1evXYY49JkiZOnKiHHnqo3wFdLpcVhwQASEnb2paEe2VlpXbu3GnFUAAAC3CH\nKgAYiHAHAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsAGIhwBwADEe4AYCDCHQAMRLgD\ngIEIdwAwEOEOAAYi3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAGItwBwECEOwAY\niHAHAAMR7gBgIMIdAAxEuAOAgQh3ADCQ06qBtm3bpvfff19+v1/r1q1TZWWlVUMDAIbIssp99uzZ\n+uUvf6ny8nKrhgQAXCDLKvcrrrhCkhQKhawaEgBwgSwL98EIBoMKBoMx2+x2u9xudzqnAQDGCAQC\n6unpidnmcrnSG+6NjY1qaGiI2VZeXq76+np5PJ50TkVerzetx8tWrEMf1qIPa9FrNKzD6tWr5ff7\nY7bV1NSkN9wXLlyoG2+8MWab3d7b9vf5fOru7k7LPLxer1pbW9NyrGzGOvRhLfqwFr2yfR0cDoc8\nHo/q6uoyX7m7XC65XK50HhIAjJasrW1ZuG/dulXNzc3q6OjQ448/rqKiIq1bt86q4QEAQ2BZuC9d\nulRLly61ajgAwDBwhyoAGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwEOEOAAYi3AHAQIQ7ABgorc+W\nATB47bailPuUhTrTMBOMRlTuQJZptxUNKtjD+wKJULkDWSI6qNtPn0u6X1lhXjqmg1GOcAcyKL7y\nHijUo/ch4JEK4Q5kQLIqveN0V9LvGV84ZkTnBLMQ7kCaDCbQT57pX7mXFiSv0jmhimQId2CEJQr1\ncKAnCvNUaMlgMAh3YISEQz1ZoJ9M0oIpHWT7haodAyHcAQslq9LjA70j7sTp+CTVeLglE+63h6t2\ngh2pEO6ABQYT6tGBfvJMl0oL+lfo4aqdUMdwEe7AMMSHenTrJbpKP3mmfwsmacAX5PULdYlgx9AQ\n7sAFGKifHl2lh0M9umqPb8GEX5cWjkkY7IQ6LgThDgxBolBP1Ho5eaarX189XmnBGI0vzIuEutTb\nhiHUYQXCHRiERO2XRK2XVIEeqdLjgp1qHVYj3IEUoqv16PbL4UAwJtDbE1zaWBZ1WeP4wjyqdaQN\n4Q4kMVC1Hl2pt5/uUkeCm5HGR91ZmijYCXWMJMIdSCBZb32wwR5tfGGeJk9wJWzBEOoYKYQ7ECe+\nDRN/wjTRZY3RxhfkqaxwDNU6MopwB742ULXe+7p/hV5WOCamry4lb8H07k+wIz0Id+S8dluR2o99\nOaTHBUix16uHb0bimnVkC8vC/dixY6qvr1dnZ6eKi4tVW1uriooKq4YHRsSFVOvxd5XGB3rvNkId\nmWVZuG/evFkLFizQnDlztGfPHv3+97/XqlWrrBoesFSqO0wHElOx8ywYZClLwv3LL7/U4cOHVV1d\nLUmqrq7Wiy++qFOnTqm4uNiKQwCWiA91aXDPVh9MoEuEOr
KHJeEeCARUVlYmm80mSbLb7SotLVVb\nWxvhjqxwob8wYyhVukSoI3uk9YRqMBhUMBiM2Wa32+V2u9M5DeSYoVbryX5ZBqGObBQIBNTT0xOz\nzeVyWRPubrdb7e3tCoVCstls6unp0cmTJzVhwoSY/RobG9XQ0BCzrby8XPX19fJ4PFZMZdC8Xm9a\nj5etTF+HQ8e+7LdtoF9CHW+g1oskXX1Ryddflcgkpv9/MVijYR1Wr14tv98fs62mpsaacC8pKdGk\nSZPU1NSkuXPnqqmpSZMnT+7Xklm4cKFuvPHGmG12u12S5PP51N3dbcV0UvJ6vWptbU3LsbKZ6esQ\n3YoZitKYxwYMXKW3tppXsZv+/8VgZfs6OBwOeTwe1dXVjVzlLkk/+clPVF9fr9dee00ul0vLly/v\nt4/L5ZLL5bLqkMAFGV84JlK9Rwd5/D4SrRdkv2RtbcvC3ev16oknnrBqOGDYykKdSav38Un66pHv\nJdQxynGHKnJGOLCjT6xGh3i//Ql1jGKEO4yWqHpPFOhXX1SS1f1VYKjsmZ4AMNIGqsDLQp1U6DAS\nlTtyAgGOXEPlDgAGItwBwECEOwAYiHAHAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsA\nGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwEOEOAAYi3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CB\nCHcAMBDhDgAGItwBwEDO4Q6wZ88evfnmmzpy5IiWLFmi+fPnWzEvAMAwDLtynzx5su677z7NmTPH\nivkAACww7Mq9srJSkmSz2YY9GQCANYYd7kMRDAYVDAZjttntdrnd7nROAwCMEQgE1NPTE7PN5XKl\nDvcVK1aora0tZlsoFJLNZtPmzZuHVLE3NjaqoaEhZlt5ebnq6+vl8XgGPY4VvF5vWo+XrViHPqxF\nH9ai12hYh9WrV8vv98dsq6mpkS0UCoWsOMCGDRt06aWXDnhCdaDK3efzqbu724qppOT1etXa2pqW\nY2Uz1qEPa9GHteiV7evgcDjk8XguvHIfilSfEy6XSy6Xy8pDAkBOS9bWHvbVMnv37tWyZcu0f/9+\n7dq1S8uWLdPRo0eHOywAYBiGXblXV1erurrairkAACzCHaoAYCDCHQAMRLgDgIEIdwAwEOEOAAYi\n3AHAQIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAGItwBwECEOwAYiHAHAAMR7gBgIMId\nAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAhDsAGIhwBwADEe4AYCDCHQAMRLgDgIEIdwAwkDPTE8iErw6+\np7IE29srLkn7XABgJAw73Lds2aJDhw4pLy9P+fn5WrJkiaZMmWLF3CxVdvyzmNfdgRMxrx3uif32\nCSP0AYw2ww73GTNmaOnSpbLb7Tp48KCeeeYZrV+/3oq5WSJVqMdvd7gnJh2DkAcwWgw73KuqqiJf\nX3755Wpvbx/ukJYJh3J0oJ/3+4Y0RqKwB4BsZ2nP/e23344J+0wZKNS7AwOHe3fAJ4fbI2e5JzJG\nOODLjn9G9Q5gVEgZ7itWrFBbW1vMtlAoJJvNps2bN8tms0mS9u7dq3379qmuri7pWMFgUMFgMGab\n3W6X2+2+kLknFB/s5/2+foE+UMA73J7I9xHwALJdIBBQT09PzDaXyyVbKBQKDXfw5uZmbd++XatW\nrRowqHft2qWGhoaYbeXl5aqvrx/W8b86+F7M6/hgjw7zZMEeDnWH2xP5Ohzuvdtj2zP5VTcMa84A\nYIXa2lr5/f6YbTU1NcMP95aWFm3dulUrV66Ux+MZcN+BKnefz6fu7u4hHz/6hGmyNkx0wHf7+/Zx\nlPcFdnS4R/8ZDvj4cDehevd6vWptbc30NLICa9GHteiV7evgcDjk8XiSVu7D7rlv3LhReXl5evrp\npyPtmpUrV6qoqKjfvi
6XSy6Xa7iHjIgP9viTpf2qdn/slTLd/hORgA/32sN/AsBokKxbMuxwf+GF\nF4Y7xJANdHljqhOm8RIFfPzXADDajLrHDyS60SjV5YpDDfzeMT1Gt2QAmG1UPX4g2R2kFyq65550\nH4IdwCg06ir3gTjLPQlbKQm3DRDs4aqdYAcwWo2qyn2wwidGw18n2p7s+5IFOwCMJkaEe7LnxUSL\nvsQxUcBHB3siVO0ARpNRFe7tFZek7Ls7yz0pnx8T36ZJdU07AIw2oyrcpdgKOhz0DvfEmOo9WfU9\nkIGCnaodwGgz6sI9WnQlP5hqO1lFzyWPAEwzqsNd6h/AQwn7wY4JAKPNqA/3eImCOVngJ/oFHQQ7\nABMYF+6JhAM7/mQsbRgApsqJcA8Lh3e2P+0NAIbLqDtUAQC9CHcAMBDhDgAGItwBwECEOwAYiHAH\nAAMR7gBgIMIdAAxEuAOAgQh3ADAQ4Q4ABiLcAcBAWfPgMLs9vZ8zDocjrcfLVqxDH9aiD2vRK5vX\nIVVm2kKhUChNcwEApEnOtWUCgYBqa2sVCAQyPZWMYh36sBZ9WIteJqxDzoV7T0+P/H6/enp6Mj2V\njGId+rAWfViLXiasQ86FOwDkAsIdAAxEuAOAgRxr1qxZk+lJpFteXp6uuuoqjRkzJtNTySjWoQ9r\n0Ye16DXa14FLIQHAQLRlAMBAhDsAGChrHj+Qblu2bNGhQ4eUl5en/Px8LVmyRFOmTMn0tNJuz549\nevPNN3XkyBEtWbJE8+fPz/SU0urYsWOqr69XZ2eniouLVVtbq4qKikxPK+22bdum999/X36/X+vW\nrVNlZWWmp5QxnZ2dWr9+vU6cOCGn06mKigrdddddKi4uzvTUhiQnT6iG/fCHP9TNN9+scePGaePG\njbr11lszPaW0czgcuuGGG3T69GmVlZVp6tSpmZ5SWj3zzDO66aabdNdddykvL0+vvfaavv3tb2d6\nWmmXn5+v73znO2pubtacOXNUUlKS6SllTFdXly666CLdcccduvnmm/XJJ5+opaVFM2fOzPTUhiRn\n2zJVVVWRB+9cfvnlam9vz/CMMqOyslIXX3yxbDZbpqeSdl9++aUOHz6s6upqSVJ1dbUOHz6sU6dO\nZXhm6XfFFVeorKxMXF8hFRUVadq0aZHXl1122ah8DEHOhnu0t99+W1VVVZmeBtIsEAiorKws8sFm\nt9tVWlqqtra2DM8M2SIUCumdd94ZdVW7ZHDPfcWKFf3+koZCIdlsNm3evDnyF3rv3r3at2+f6urq\nMjHNETfYdQDQ35YtW1RQUKAFCxZkeipDZmy4r127NuU+zc3N2rlzp1atWmVsj3Ew65Cr3G632tvb\nIx92PT09OnnypCZMmJDpqSELbNu2TT6fT4888kimp3JBcrYt09LSoj/84Q/6xS9+IbfbnenpZIVc\n67eWlJRo0qRJampqkiQ1NTVp8uTJo+6qCFhvx44dOnz4sB5++OGs/oUdA8nZO1R//OMfKy8vTyUl\nJZHKbeXKlSoqKsr01NJq79692r59u4LBoJxOp8aOHavHHntMF198caanlhatra2qr69XMBiUy+XS\n8uXLddFFF2V6Wmm3detWNTc3q6OjQyUlJSoqKtK6desyPa2MOHLkiB588EF5vV7l5eVJkiZOnKiH\nHnoowzMbmpwNdwAwWc62ZQDAZIQ7ABiIcAcAAxHuAGAgwh0ADES4A4CBCHcAMBDhDgAG+n+UFdHi\na8/mXgAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0xc7fad79c0d0\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "loc_ = loc.numpy()\n",
        "# Draw a filled 2-D KDE of the posterior samples of each component's mean.\n",
        "# `x=`/`y=` keywords and `fill`/`thresh` replace the positional arguments and\n",
        "# `shade`/`shade_lowest` options that newer seaborn versions no longer accept.\n",
        "for k in range(loc_.shape[1]):\n",
        "  sns.kdeplot(x=loc_[:, k, 0], y=loc_[:, k, 1], fill=True, thresh=0.05)\n",
        "plt.title('KDE of loc draws');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NmfNIM1c6mwc"
      },
      "source": [
        "## Conclusion"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "t8LeIeMn6ot4"
      },
      "source": [
        "This simple colab demonstrated how TensorFlow Probability primitives can be used to build a hierarchical Bayesian mixture model and draw posterior samples from it with Hamiltonian Monte Carlo."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Bayesian Gaussian Mixture Model",
      "private_outputs": false,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
